# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for test."""
import numpy as np

def get_init_params(num_embedding, embedding_dim):
    """Return reproducible initialization parameters for the test.

    Side effect: the global NumPy RNG is reseeded with 42 on every call.
    Every call therefore returns identical arrays, and any later
    ``np.random`` use by the caller continues from this seeded stream.

    Args:
        num_embedding: Number of rows of the generated weight matrix.
        embedding_dim: Number of columns of the generated weight matrix.

    Returns:
        A dict with, in this order:
        ``"inputs"``: float64 array of shape (3, 3) with values in [0, 10).
        ``"weight"``: float64 array of shape (num_embedding, embedding_dim)
        drawn from a normal distribution with standard deviation 0.01.

    NOTE(review): ``inputs`` are floats. They are presumably truncated to
    integer indices into ``weight`` by the caller. In that case
    ``num_embedding`` must be at least 10. Confirm against the test.
    """
    np.random.seed(42)
    # Keep this draw order ("inputs" first, then "weight"). Swapping the
    # two draws would change every sampled value.
    unit_samples = np.random.rand(3, 3)
    normal_samples = np.random.randn(num_embedding, embedding_dim)
    params = {}
    params["inputs"] = unit_samples * 10
    params["weight"] = normal_samples * 0.01
    return params

def get_golden() -> dict[str, np.ndarray]:
    """Generate golden data for test.

    Returns:
        A dict with the single key ``"output_only"`` mapped to a float64
        array of shape (3, 3, 16). The nine 16-value rows are listed below
        in row-major order over the leading (3, 3) positions and reshaped
        into that layout. A new array is built on every call.
    """
    # Positions (1, 1) and (1, 2) hold exactly the same 16 values, so that
    # row is written once and placed twice. np.array copies it, so the two
    # entries do not alias each other.
    repeated_row = [
        1.22219161e-03, -5.15435683e-03, -6.00253837e-03, 9.47439857e-03,
        2.91033997e-03, -6.35559717e-03, -1.02155218e-02, -1.61755388e-03,
        -5.33648813e-03, -5.52786223e-05, -2.29450455e-03, 3.89348902e-03,
        -1.26511911e-02, 1.09199230e-02, 2.77831312e-02, 1.19363973e-02,
    ]
    flat_rows = [
        # position (0, 0)
        [-1.50815332e-02, 1.09964693e-02, -1.77732122e-03, -4.10383288e-03,
         1.17971636e-02, -8.98207910e-03, 8.34795460e-03, 2.96561373e-03,
         -1.03782983e-02, -7.58037437e-04, 9.72963497e-03, 7.95595441e-03,
         1.49543425e-02, 3.38181248e-03, 3.37229632e-02, -9.20390803e-03],
        # position (0, 1)
        [-1.44004142e-02, 1.19072730e-02, 1.29939681e-02, -8.67146160e-03,
         6.17640838e-03, 1.21707078e-02, 2.26288266e-03, 8.47401470e-03,
         1.74833008e-03, -1.21685490e-02, 1.04934741e-02, 1.32510569e-02,
         7.34501053e-03, -9.54497233e-03, -7.51179410e-03, -1.13042807e-02],
        # position (0, 2)
        [7.67296832e-03, -1.14979260e-02, -7.75336102e-03, 7.73140835e-03,
         -8.01827852e-03, 1.38401575e-02, 1.40520530e-02, 1.39232576e-02,
         -8.80640838e-03, 7.68949511e-04, -4.93432488e-03, 9.23162606e-03,
         1.70660503e-02, 8.73589423e-03, 9.14434568e-05, -3.65539291e-03],
        # position (1, 0)
        [1.37186209e-02, 1.75553293e-03, -3.09288548e-03, 6.73125498e-03,
         -2.56630173e-03, -3.67825734e-03, 1.27373366e-02, -2.91952677e-03,
         -2.65517607e-02, 3.45517951e-03, -3.95516446e-03, -2.89136847e-03,
         4.52936348e-03, -1.66060904e-03, 2.14938819e-03, -2.02231500e-02],
        # positions (1, 1) and (1, 2)
        repeated_row,
        repeated_row,
        # position (2, 0)
        [-5.80878137e-03, -5.25169820e-03, -5.71380183e-03, -9.24082845e-03,
         -2.61254907e-02, 9.50369705e-03, 8.16445053e-03, -1.52387600e-02,
         -4.28046053e-03, -7.42406817e-03, -7.03343796e-03, -2.13962067e-02,
         -6.29474968e-03, 5.97720454e-03, 2.55948808e-02, 3.94233037e-03],
        # position (2, 1)
        [6.49086712e-03, -1.22287357e-02, 5.36336051e-03, -9.14690923e-03,
         6.20548194e-03, -1.60937372e-03, -3.88264400e-03, -8.85512400e-03,
         -3.56745021e-03, 5.56121813e-03, 1.04386061e-02, 5.26448153e-03,
         1.36388652e-02, 2.53916271e-02, -3.24490969e-03, -2.05866713e-03],
        # position (2, 2)
        [-9.43056773e-03, 1.40395872e-02, -1.85508048e-04, -1.67350471e-02,
         -1.07253185e-02, -9.92586184e-03, 1.02347683e-03, -4.32609301e-03,
         -6.59182295e-03, 3.93730443e-05, 4.77754092e-03, -2.59028655e-03,
         -5.74709196e-03, -4.21498204e-03, 3.39820958e-03, -7.38015005e-05],
    ]
    expected = np.array(flat_rows).reshape(3, 3, 16)
    return {"output_only": expected}

def get_gpu_datas() -> dict[str, np.ndarray]:
    """Generate gpu data for test.

    The GPU reference output is currently identical, element for element,
    to the golden data in ``get_golden()``. It is therefore derived from
    that function rather than duplicating the same 144-value literal, so
    the two cannot silently drift apart through a one-sided edit. If the
    GPU reference ever needs to differ from the golden data, replace this
    body with an explicit array again.

    Returns:
        A new dict with the single key ``"output_only"`` mapped to the
        array built by ``get_golden()``. That array has 3 x 3 rows of 16
        values, and a new array is built on every call.
    """
    golden = get_golden()
    # Copy only the documented key so this contract stays fixed even if
    # get_golden() later grows extra entries.
    return {"output_only": golden["output_only"]}

# Reference data, evaluated once at import time (golden first, then GPU).
GOLDEN_DATA, GPU_DATA = get_golden(), get_gpu_datas()
